# ---------------------------------------------------------------------
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# ---------------------------------------------------------------------

from __future__ import annotations

from collections.abc import Callable
from enum import Enum

import numpy as np
import torch
from PIL.Image import Image
from qai_hub.client import DatasetEntries
from torch.utils.data import DataLoader
from torchvision.transforms.functional import resize

from qai_hub_models.datasets import DatasetSplit, get_dataset_from_name
from qai_hub_models.models.sam2.model_patches import (
    mask_postprocessing as upscale_masks,
)
from qai_hub_models.utils.base_model import BaseModel, CollectionModel
from qai_hub_models.utils.evaluate import sample_dataset
from qai_hub_models.utils.image_processing import (
    numpy_image_to_torch,
    preprocess_PIL_image,
)
from qai_hub_models.utils.input_spec import InputSpec, get_batch_size
from qai_hub_models.utils.qai_hub_helpers import make_hub_dataset_entries


class SAM2InputImageLayout(Enum):
    """Channel ordering of the images handled by SAM2App."""

    # SAM2App.predict_embeddings reads the member's *name* (not its integer
    # value) to decide how PIL inputs are converted.
    RGB = 0
    BGR = 1


class SAM2App:
    """
    This class consists of light-weight "app code" that is required to perform end to end inference with Segment-Anything-2 Model.

    The app uses 2 models:
        * encoder (input image + point prompts --> image embeddings, high-resolution features and sparse prompt embeddings, all consumed by the decoder)
        * decoder (image embeddings, high-resolution features and sparse prompt embeddings --> predicted segmentation masks and scores)
    """

    def __init__(
        self,
        encoder_input_img_size: int,
        mask_threshold: float,
        input_image_channel_layout: SAM2InputImageLayout,
        sam2_encoder: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor],
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        ],
        sam2_decoder: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
            tuple[torch.Tensor, torch.Tensor],
        ],
    ) -> None:
        """
        Initializes the segmentation model with encoder and decoder components.

        Parameters
        ----------
            mask_threshold (float): Threshold value used to binarize the predicted masks.

            input_image_channel_layout: SAMInputImageLayout
                Channel layout ("RGB" or "BGR") expected by the encoder.

            sam2_encoder (Callable):
                SAM2 encoder. Must match input & output of each model part generated by qai_hub_models.models.sam2.model.SAM2Encoder
                Takes image to produce the image embeddings and high-resolutions features for sam2 decoder

            sam2_decoder (Callable):
                SAM2 decoder. Must match input and output of qai_hub_models.models.sam2.model.SAM2Decoder.
                Takes image embeddings and high-resolution features to produce segmentation masks and confidence scores.
        """
        self.sam2_encoder = sam2_encoder
        self.sam2_decoder = sam2_decoder
        self.mask_threshold = mask_threshold
        self.encoder_input_img_size = encoder_input_img_size
        self.box = None
        self.normalize_coords = True
        self.mask_input = None
        self.input_image_channel_layout = input_image_channel_layout

    def predict(self, *args, **kwargs):
        return self.predict_mask_from_points(*args, **kwargs)

    def predict_mask_from_points(
        self,
        pixel_values_or_image: torch.Tensor | np.ndarray | Image | list[Image],
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts segmentation masks from input image(s) and point-based prompts.

        This function combines the embedding generation and mask prediction steps.
        It first encodes the input image(s) to obtain image embeddings and high-resolution
        features, then uses the provided point coordinates and labels to predict segmentation masks.

        Parameters
        ----------
            B-Batch_size, N-Num_Points, H-Height, W-Width, C-Channel
            pixel_values_or_image: torch.Tensor
                PIL image
                or
                numpy array (B H W C x uint8) or (H W C x uint8)
                    channel layout consistent with self.input_image_channel_layout
                or
                pyTorch tensor (B C H W x int8, value range is [0, 255])
                    channel layout consistent with self.input_image_channel_layout
            point_coords : Coordinates of points used as prompts (shape: [N, 2] or [B, N, 2]).
            point_labels : Labels for the points (shape: [N] or [B, N]).
            return_logits (bool, optional): If True, return raw logits; otherwise, apply thresholding. Defaults to False.

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - upscaled_masks: torch.Tensor of shape [b, 1, <input image spatial dims>]
                    The predicted segmentation masks, upscaled to original size.
                - scores: torch.Tensor of shape [b, 1]
                    Confidence scores for each predicted mask.

        Where,
            b = number of input images
        """
        (
            image_embeddings,
            high_res_features1,
            high_res_features2,
            sparse_embedding,
            input_images_original_size,
        ) = self.predict_embeddings(
            pixel_values_or_image,
            point_coords,
            point_labels,
        )
        return self.predict_mask_from_points_and_embeddings(
            image_embeddings,
            high_res_features1,
            high_res_features2,
            sparse_embedding,
            input_images_original_size,
            return_logits,
        )

    def predict_embeddings(
        self,
        pixel_values_or_image: torch.Tensor | np.ndarray | Image | list[Image],
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, int]]:
        """
        Generates image embeddings and high-resolution features from input image(s).

        This function processes the input image(s) through a transformation pipeline and
        a series of encoder splits to produce image embeddings and high-resolution features
        required for downstream tasks such as segmentation.

        Parameters
        ----------
            B-Batch_size, N-Num_Points, H-Height, W-Width, C-Channel
            pixel_values_or_image: torch.Tensor
                PIL image
                or
                numpy array (B H W C x uint8) or (H W C x uint8)
                    channel layout consistent with self.input_image_channel_layout
                or
                pyTorch tensor (B C H W x int8, value range is [0, 255])
                    channel layout consistent with self.input_image_channel_layout
            point_coords : Coordinates of points used as prompts (shape: [N, 2] or [B, N, 2]).
            point_labels : Labels for the points (shape: [N] or [B, N]).

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
                - image_embeddings (torch.Tensor) [1,256,64,64]: The image embeddings from the encoder.
                - high_res_features1 (torch.Tensor) [1, 32, 256, 256]: First set of high-resolution features.
                - high_res_features2 (torch.Tensor) [1, 64, 128, 128]: Second set of high-resolution features.
                - sparse_embeddings (torch.Tensor) [1, N+1, 256]: The sparse embeddings from the prompt encoder.
                - input_images_original_size: tuple[int, int]: Original size of input image (BEFORE reshape to fit encoder input size)

        Discussion:
            It is faster to run this once on an image (compared to the entire encoder / decoder pipeline)
            if masks will be predicted several times on the same image.
        """
        # Translate input to torch tensor of shape [N, C, H, W]
        if isinstance(pixel_values_or_image, Image):
            pixel_values_or_image = [pixel_values_or_image]
        if isinstance(pixel_values_or_image, list):
            NCHW_int8_torch_frames = torch.cat(
                [
                    preprocess_PIL_image(
                        x.convert(self.input_image_channel_layout.name), True
                    )
                    for x in pixel_values_or_image
                ]
            )
        elif isinstance(pixel_values_or_image, np.ndarray):
            NCHW_int8_torch_frames = numpy_image_to_torch(pixel_values_or_image, True)
        else:
            NCHW_int8_torch_frames = pixel_values_or_image

        # Resize input image to the encoder's desired input size.
        input_images_original_size = (
            NCHW_int8_torch_frames.shape[2],
            NCHW_int8_torch_frames.shape[3],
        )
        # Run encoder
        image = resize(
            NCHW_int8_torch_frames,
            [self.encoder_input_img_size, self.encoder_input_img_size],
        )

        # Expand point_coords and point_labels to include a batch dimension, if necessary
        if len(point_coords.shape) == 2:
            point_coords = torch.unsqueeze(point_coords, 0)
        if len(point_labels.shape) == 1:
            point_labels = torch.unsqueeze(point_labels, 0)

        h, w = input_images_original_size
        point_coords = point_coords.clone().float()
        point_coords[..., 0] = point_coords[..., 0] / w
        point_coords[..., 1] = point_coords[..., 1] / h

        image_embeddings, high_res_features1, high_res_features2, sparse_embedding = (
            self.sam2_encoder(image, point_coords, point_labels.float())
        )

        return (
            image_embeddings,
            high_res_features1,
            high_res_features2,
            sparse_embedding,
            input_images_original_size,
        )

    def predict_mask_from_points_and_embeddings(
        self,
        image_embeddings: torch.Tensor,
        high_res_features1: torch.Tensor,
        high_res_features2: torch.Tensor,
        sparse_embedding: torch.Tensor,
        input_images_original_size: tuple[int, int],
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts segmentation masks from image embeddings and point-based prompts.

        This function takes image embeddings, high-resolution features, and user-provided
        point coordinates and labels to generate segmentation masks using a decoder.
        The resulting masks are upscaled to the original image size and optionally thresholded.

        Parameters
        ----------
            image_embeddings (torch.Tensor) [1,256,64,64]: The image embeddings from the encoder.
            high_res_features1 (torch.Tensor) [1, 32, 256, 256]: First set of high-resolution features.
            high_res_features2 (torch.Tensor) [1, 64, 128, 128]: Second set of high-resolution features.
            sparse_embeddings (torch.Tensor) [1, N+1, 256]: The sparse embeddings from the prompt encoder.
            input_images_original_size (tuple[int, int]): Original size of the input image (height, width).
            input_images_original_size: tuple[int, int]: Original size of input image (BEFORE reshape to fit encoder input size)
            return_logits (bool): If True, return raw logits; otherwise, apply thresholding.

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - upscaled_masks: torch.Tensor of shape [b, 1, <input image spatial dims>]
                    The predicted segmentation masks, upscaled to original size.
                - scores: torch.Tensor of shape [b, 1]
                    Confidence scores for each predicted mask.

        Where,
            b = number of input images
        """
        # Run decoder
        masks, scores = self.sam2_decoder(
            image_embeddings,
            high_res_features1,
            high_res_features2,
            sparse_embedding,
        )

        # Upscale masks
        upscaled_masks = upscale_masks(masks, input_images_original_size)

        # Apply mask threshold
        if not return_logits:
            upscaled_masks = upscaled_masks > self.mask_threshold

        return upscaled_masks, scores

    @classmethod
    def get_calibration_data(
        cls,
        model: BaseModel,
        calibration_dataset_name: str,
        num_samples: int | None,
        input_spec: InputSpec,
        collection_model: CollectionModel,
    ) -> DatasetEntries:
        batch_size = get_batch_size(input_spec) or 1
        encoder = collection_model.components["SAM2Encoder"]
        assert isinstance(encoder, BaseModel)
        dataset = get_dataset_from_name(
            calibration_dataset_name,
            split=DatasetSplit.TRAIN,
            input_spec=encoder.get_input_spec(),
        )
        num_samples = num_samples or dataset.default_num_calibration_samples()
        num_samples = (num_samples // batch_size) * batch_size
        print(f"Loading {num_samples} calibration samples.")
        torch_dataset = sample_dataset(dataset, num_samples)
        dataloader = DataLoader(torch_dataset, batch_size=batch_size)
        inputs: list[list[torch.Tensor | np.ndarray]] = [
            [] for _ in range(len(input_spec))
        ]
        for sample_input, _ in dataloader:
            if model._get_name() == "SAM2Decoder":
                sample_input = encoder(*sample_input)
            if isinstance(sample_input, (tuple, list)):
                for i, tensor in enumerate(sample_input):
                    inputs[i].append(tensor)
            else:
                inputs[0].append(sample_input)
        return make_hub_dataset_entries(tuple(inputs), list(input_spec.keys()))
